{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c2b92e14-aa49-4b4c-b0c3-af07314eb8a7",
   "metadata": {},
   "source": [
    "# Supplementary material for Q12"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b9622508-9b82-4d12-9918-aefc068af4a0",
   "metadata": {},
   "source": [
    "Under which circumstances are fully connected and convolutional layers equivalent?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a99e7ac1-9681-45c1-a4ef-715491d22828",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PyTorch version: 2.0.0\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "print(f\"PyTorch version: {torch.__version__}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "846a864a-76d8-4d64-a429-960bd0feb29a",
   "metadata": {},
   "source": [
    "## 1) Reference fully connected layer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f3eb52f1-0e48-4423-a8dc-a0b9c380cb6d",
   "metadata": {},
   "source": [
    "<img src=\"img/fc-cnn-equivalent-1.png\" width=\"400px\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "036c0cdd-dfa4-49cf-9572-a6a727e4170f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-0.4775, -2.1469]])\n"
     ]
    }
   ],
   "source": [
    "# Fix the seed so that the randomly initialized weights are reproducible\n",
    "torch.manual_seed(123)\n",
    "\n",
    "fc = torch.nn.Linear(4, 2)\n",
    "\n",
    "inputs = torch.tensor([[1., 2., 3., 4.]])\n",
    "\n",
    "with torch.no_grad():\n",
    "    out1 = fc(inputs)\n",
    "\n",
    "print(out1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "59e6fb1a-86cc-467c-a401-52ac19ae5744",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Parameter containing:\n",
       "tensor([[-0.2039,  0.0166, -0.2483,  0.1886],\n",
       "        [-0.4260,  0.3665, -0.3634, -0.3975]], requires_grad=True)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fc.weight"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db888c2a-3f28-4e22-91ed-1f92d25a26d3",
   "metadata": {},
   "source": [
    "## 2) Scenario 1: The kernel size is equal to the input size"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "342b9ac2-56b2-4b8b-b7be-d6272b9abd52",
   "metadata": {},
   "source": [
    "<img src=\"img/fc-cnn-equivalent-2.png\" width=\"500px\">"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f9bf351-c5d1-4b99-98ad-9d427e8205c0",
   "metadata": {},
   "source": [
    "Convolutional layers in PyTorch expect inputs in NCHW format by default, where\n",
    "\n",
    "- N = batch size\n",
    "- C = channels\n",
    "- H = height\n",
    "- W = width"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ba3dbf8d-76dc-4a1a-8f5e-f3e09db14e62",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[[1., 2.],\n",
       "          [3., 4.]]]])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape (N, C, H, W) = (1, 1, 2, 2): one sample, one channel, 2x2 input\n",
    "reshaped = inputs.reshape(-1, 1, 2, 2)\n",
    "reshaped"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "12fa563f-c33e-4463-8bc3-012a0178dcaa",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 1, 2, 2])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "conv = torch.nn.Conv2d(\n",
    "    in_channels=1,\n",
    "    out_channels=2,\n",
    "    kernel_size=2\n",
    ")\n",
    "\n",
    "conv.weight.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1da7fe7e-3fb3-4211-9c50-3197901cb559",
   "metadata": {},
   "source": [
    "Note that the weights and biases in `Conv2d` are also initialized randomly. To reproduce the output of the fully connected layer exactly, we overwrite them with the weights and biases of the fully connected layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "fd107812-7a0f-49af-84c3-2703af1d9dd5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[[-0.4775]],\n",
      "\n",
      "         [[-2.1469]]]])\n"
     ]
    }
   ],
   "source": [
    "with torch.no_grad():\n",
    "    # The 2x2 kernel of each output channel holds one row of the fully connected weight matrix\n",
    "    conv.weight[0][0] = fc.weight[0].reshape(2, 2)\n",
    "    conv.weight[1][0] = fc.weight[1].reshape(2, 2)\n",
    "    conv.bias[0] = fc.bias[0]\n",
    "    conv.bias[1] = fc.bias[1]\n",
    "\n",
    "    out2 = conv(reshaped)\n",
    "\n",
    "print(out2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "2d55e289-a616-4128-9ec2-bd3b40616dd3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([True, True])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out1.flatten() == out2.flatten()"
   ]
  },
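  {
   "cell_type": "markdown",
   "id": "7a1c3e52-5b0d-4f6e-9c21-3d8e4b7a1f02",
   "metadata": {},
   "source": [
    "The same check can be sketched for a batch of inputs and a non-square input. This is a minimal sketch rather than part of the original example: the sizes and the names `batch_size`, `height`, `width`, `out_features`, `fc_b`, `conv_full`, and `x` are assumptions chosen for illustration. It uses `torch.allclose` instead of `==` so that small floating-point differences between the two layers do not cause a false mismatch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4e8a2d6-1f37-4b59-8a60-92d5f0b3e714",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Hypothetical sizes, chosen for illustration only\n",
    "batch_size, height, width, out_features = 8, 3, 5, 6\n",
    "\n",
    "fc_b = torch.nn.Linear(height * width, out_features)\n",
    "# A single input channel and a kernel that covers the whole input\n",
    "conv_full = torch.nn.Conv2d(1, out_features, kernel_size=(height, width))\n",
    "\n",
    "x = torch.randn(batch_size, height * width)\n",
    "\n",
    "with torch.no_grad():\n",
    "    # Row i of the fully connected weight matrix becomes the kernel of output channel i\n",
    "    conv_full.weight.copy_(fc_b.weight.reshape(out_features, 1, height, width))\n",
    "    conv_full.bias.copy_(fc_b.bias)\n",
    "\n",
    "    out_fc = fc_b(x)\n",
    "    # The convolution output has shape (N, out_features, 1, 1); flatten it to (N, out_features)\n",
    "    out_conv = conv_full(x.reshape(batch_size, 1, height, width)).flatten(start_dim=1)\n",
    "\n",
    "# Compare with a tolerance, since the two layers may sum the products in a different order\n",
    "print(torch.allclose(out_fc, out_conv, atol=1e-6))"
   ]
  },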
  {
   "cell_type": "markdown",
   "id": "6c830a5a-197e-4940-8de9-a6b1dae98b17",
   "metadata": {},
   "source": [
    "## 3) Scenario 2: The kernel has size one"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e8aaebca-4369-474c-9244-0d6a18b42f27",
   "metadata": {},
   "source": [
    "<img src=\"img/fc-cnn-equivalent-3.png\" width=\"500px\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "de63bca7-17a2-489f-958f-41e96cfa21ed",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[[1.]],\n",
       "\n",
       "         [[2.]],\n",
       "\n",
       "         [[3.]],\n",
       "\n",
       "         [[4.]]]])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape (N, C, H, W) = (1, 4, 1, 1): each input feature becomes its own channel\n",
    "reshaped2 = inputs.reshape(-1, 4, 1, 1)\n",
    "reshaped2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "50f69097-94c3-4d85-835d-1ae5e369c67a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 4, 1, 1])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "conv = torch.nn.Conv2d(\n",
    "    in_channels=4,\n",
    "    out_channels=2,\n",
    "    kernel_size=1\n",
    ")\n",
    "\n",
    "conv.weight.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "b7e805d1-73de-4ecf-97df-5e7c2ba410e5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[[-0.4775]],\n",
      "\n",
      "         [[-2.1469]]]])\n"
     ]
    }
   ],
   "source": [
    "with torch.no_grad():\n",
    "    # Each output channel holds one row of the fully connected weight matrix as four 1x1 kernels\n",
    "    conv.weight[0] = fc.weight[0].reshape(4, 1, 1)\n",
    "    conv.weight[1] = fc.weight[1].reshape(4, 1, 1)\n",
    "    conv.bias[0] = fc.bias[0]\n",
    "    conv.bias[1] = fc.bias[1]\n",
    "\n",
    "    out3 = conv(reshaped2)\n",
    "\n",
    "print(out3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "721f88fd-b174-4801-8ed3-be2b85a2c41f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([True, True])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out1.flatten() == out3.flatten()"
   ]
  }
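  ,
  {
   "cell_type": "markdown",
   "id": "e2b7d915-6c48-4a03-b1f5-0a9c7d3e8b26",
   "metadata": {},
   "source": [
    "Scenario 2 can likewise be sketched for a batch of inputs with an arbitrary number of features. Again, this is a minimal sketch under stated assumptions: the sizes and the names `batch_size`, `in_features`, `out_features`, `fc_c`, `conv_1x1`, and `x` are hypothetical and chosen for illustration. Each input feature is placed in its own channel of a 1x1 input, so the 1x1 convolution reduces to a weighted sum over channels plus a bias, which is the same computation as the fully connected layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58f0c6a3-d2e9-4417-a5bc-6e1d9f24c083",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Hypothetical sizes, chosen for illustration only\n",
    "batch_size, in_features, out_features = 8, 10, 6\n",
    "\n",
    "fc_c = torch.nn.Linear(in_features, out_features)\n",
    "# One input channel per feature and a 1x1 kernel\n",
    "conv_1x1 = torch.nn.Conv2d(in_features, out_features, kernel_size=1)\n",
    "\n",
    "x = torch.randn(batch_size, in_features)\n",
    "\n",
    "with torch.no_grad():\n",
    "    # The (out_features, in_features) weight matrix maps directly onto the 1x1 kernels\n",
    "    conv_1x1.weight.copy_(fc_c.weight.reshape(out_features, in_features, 1, 1))\n",
    "    conv_1x1.bias.copy_(fc_c.bias)\n",
    "\n",
    "    out_fc = fc_c(x)\n",
    "    # The convolution output has shape (N, out_features, 1, 1); flatten it to (N, out_features)\n",
    "    out_conv = conv_1x1(x.reshape(batch_size, in_features, 1, 1)).flatten(start_dim=1)\n",
    "\n",
    "# Compare with a tolerance, since the two layers may sum the products in a different order\n",
    "print(torch.allclose(out_fc, out_conv, atol=1e-6))"
   ]
  }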
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
